package com.xuecheng.service.Impl;

import com.xuecheng.dto.CourseCategoryTreeDto;
import com.xuecheng.mapper.CourseCategoryMapper;
import com.xuecheng.service.CourseCategoryService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

@Service
public class CourseCategoryServiceImpl implements CourseCategoryService {

    /** Data-access mapper used to load the flat list of category rows. */
    @Autowired
    CourseCategoryMapper courseCategoryMapper;

    /**
     * Queries the course-category tree below the category {@code id}.
     *
     * <p>The mapper returns a flat list of categories. This method:
     * <ul>
     *   <li>drops the row whose id equals {@code id} (the root itself);</li>
     *   <li>links each remaining node into its parent's {@code childrenTreeNodes} list;</li>
     *   <li>returns the direct children of {@code id} as the top level of the tree.</li>
     * </ul>
     *
     * <p>Nodes whose parent is neither {@code id} nor another returned node are not
     * attached anywhere and do not appear in the result. If several rows share an id,
     * the last one in the list is the one that receives children.
     *
     * @param id id of the root category; must not be {@code null}
     * @return the direct children of {@code id}, each with its descendants attached
     */
    @Override
    public List<CourseCategoryTreeDto> queryTreeNodes(String id) {
        // Load the flat list of categories from the database.
        List<CourseCategoryTreeDto> courseCategoryTreeDtos = courseCategoryMapper.selectTreeNodes(id);

        // Exclude the root node once, instead of repeating the filter for every pass.
        List<CourseCategoryTreeDto> descendants = courseCategoryTreeDtos.stream()
                .filter(item -> !id.equals(item.getId()))
                .collect(Collectors.toList());

        // Index nodes by id so each child can find its parent in O(1).
        // For duplicate ids, the later entry wins.
        Map<String, CourseCategoryTreeDto> nodesById = descendants.stream()
                .collect(Collectors.toMap(CourseCategoryTreeDto::getId, node -> node,
                        (earlier, later) -> later));

        // Top-level nodes of the returned tree.
        List<CourseCategoryTreeDto> categoryTreeDtos = new ArrayList<>();

        for (CourseCategoryTreeDto item : descendants) {
            // Compare with id as the receiver so a row with a null parentid cannot throw NPE.
            if (id.equals(item.getParentid())) {
                categoryTreeDtos.add(item);
            }

            // Attach this node to its parent, if the parent is part of the result set.
            CourseCategoryTreeDto parent = nodesById.get(item.getParentid());
            if (parent != null) {
                if (parent.getChildrenTreeNodes() == null) {
                    parent.setChildrenTreeNodes(new ArrayList<CourseCategoryTreeDto>());
                }
                parent.getChildrenTreeNodes().add(item);
            }
        }
        return categoryTreeDtos;
    }
}
